{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "import lightgbm as lgb\n",
    "import numpy as np\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn import metrics\n",
    "from sklearn.metrics import accuracy_score\n",
    "import time\n",
    "from pyspark.sql import SparkSession\n",
    "from concurrent.futures import ThreadPoolExecutor\n",
    "from sklearn.ensemble import VotingClassifier\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.ensemble import BaggingClassifier\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.svm import SVC\n",
    "import xgboost as xgb\n",
    "from catboost import CatBoostClassifier\n",
    "from sklearn.ensemble import AdaBoostClassifier\n",
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "from skdist.distribute.search import DistGridSearchCV,DistRandomizedSearchCV\n",
    "from sklearn.ensemble import BaggingClassifier,RandomForestClassifier,ExtraTreesClassifier,GradientBoostingClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "# os.environ['JAVA_HOME'] = '/home/carl-hui/software/jdk-13.0.1'  # this path is the directory that contains Java's bin directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CSV (path is relative to the working directory) into all_data.\n",
    "all_data = pd.read_csv(filepath_or_buffer='data_demo4.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional split: the first 28000 rows are used for training, the rest for prediction.\n",
    "# .copy() makes both frames independent of all_data, so adding or replacing\n",
    "# columns on them later does not raise SettingWithCopyWarning.\n",
    "train_all_data = all_data[:28000].copy()\n",
    "test_all_data = all_data[28000:].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/carl-hui/software/env/AI/lib/python3.5/site-packages/ipykernel_launcher.py:1: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    }
   ],
   "source": [
    "# Encode the three label strings as integer class ids 0/1/2.\n",
    "# assign() builds a new frame instead of writing a column into a slice of\n",
    "# all_data, which is what raised SettingWithCopyWarning.\n",
    "# NOTE: not idempotent - re-running this cell maps the already-encoded integers to NaN.\n",
    "train_all_data = train_all_data.assign(\n",
    "    type=train_all_data[\"type\"].map({'围网':0,'刺网':1,'拖网':2})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separate the label vector from the feature matrix.\n",
    "# NOTE(review): only 'type' is removed, so every other column - including 'ID' -\n",
    "# is fed to the model as a feature; confirm that this is intended.\n",
    "y_train = train_all_data[\"type\"]\n",
    "X_train = train_all_data.drop(columns='type')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#X_train, X_test, y_train, y_test = train_test_split(X_data, Y_data, test_size=0.3, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run a distributed randomized hyper-parameter search for a LightGBM classifier.\n",
    "\n",
    "    A GPU multiclass LGBMClassifier is tuned with sk-dist's DistRandomizedSearchCV\n",
    "    using the Spark context of the current (or a newly created) session:\n",
    "    100 parameter settings are sampled from the grid and each one is scored\n",
    "    with 5-fold cross-validation on the 'f1_weighted' metric.\n",
    "    The search is fitted on the notebook-level X_train / y_train.\n",
    "\n",
    "    Returns:\n",
    "        The fitted DistRandomizedSearchCV object.\n",
    "    \"\"\"\n",
    "    spark = SparkSession.builder.getOrCreate()\n",
    "    sc = spark.sparkContext\n",
    "\n",
    "    param_grid = {\n",
    "        # np.arange(0.1, 0.7) falls back to a step of 1 and yields only [0.1],\n",
    "        # so the learning rate was never searched. Use explicit candidates\n",
    "        # 0.1, 0.2, ..., 0.6 (presumably the intended range).\n",
    "        'learning_rate': np.round(np.linspace(0.1, 0.6, 6), 1),\n",
    "        \"min_child_samples\": np.arange(7, 10),\n",
    "        \"max_depth\": np.arange(7, 10),\n",
    "        'boosting': [\"gbdt\", \"dart\", \"goss\"],\n",
    "        \"n_estimators\": np.arange(2000, 3000, 100),\n",
    "    }\n",
    "\n",
    "    # Base estimator; the values listed in param_grid are overridden per candidate.\n",
    "    base_clf = lgb.LGBMClassifier(\n",
    "        # num_leaves=31: maximum number of leaves per base learner (default 31)\n",
    "        learning_rate=0.01,\n",
    "        # Depends on the number of training samples and num_leaves; a larger value\n",
    "        # avoids growing an overly deep tree but may cause under-fitting.\n",
    "        min_child_samples=9,\n",
    "        # Tree depth; deeper trees may over-fit.\n",
    "        max_depth=9,\n",
    "        lambda_l1=2, boosting=\"gbdt\", objective=\"multiclass\",\n",
    "        n_estimators=2000, metric='multi_error', num_class=3,\n",
    "        feature_fraction=0.75, bagging_fraction=0.85, seed=99,\n",
    "        # verbose=True, default=True\n",
    "        # n_jobs=-1 is the default\n",
    "        device=\"gpu\", gpu_platform_id=0,\n",
    "        gpu_device_id=0,\n",
    "    )\n",
    "\n",
    "    # Randomized hyper-parameter search distributed over Spark.\n",
    "    gsearch = DistRandomizedSearchCV(\n",
    "        base_clf, param_grid, sc=sc, cv=5,\n",
    "        scoring=\"f1_weighted\", verbose=True,\n",
    "        # Number of sampled parameter settings: larger values explore more of\n",
    "        # the grid but make the search take longer.\n",
    "        n_iter=100,\n",
    "        n_jobs=-1,\n",
    "    )\n",
    "    gsearch.fit(X_train, y_train)\n",
    "    return gsearch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Spark context found; running with spark\n",
      "Fitting 5 folds for each of 100 candidates, totalling 500 fits\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-1560b909936b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mstart\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-8-d1df5fb8dea9>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m     27\u001b[0m                                 \u001b[0mn_iter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m#训练300次，数值越大，获得的参数精度越大，但是搜索时间越长\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     28\u001b[0m                                 n_jobs = -1)\n\u001b[0;32m---> 29\u001b[0;31m     \u001b[0mgsearch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     30\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mgsearch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/software/env/AI/lib/python3.5/site-packages/skdist/distribute/search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m    367\u001b[0m             \u001b[0mbase_estimator_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbroadcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbase_estimator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    368\u001b[0m             \u001b[0mpartitions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_parse_partitions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartitions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfit_sets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 369\u001b[0;31m             out = self.sc.parallelize(fit_sets, numSlices=partitions).map(lambda x: [x[0], _fit_and_score(\n\u001b[0m\u001b[1;32m    370\u001b[0m                 \u001b[0mbase_estimator_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscorers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    371\u001b[0m                 \u001b[0mverbose\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfit_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/software/env/AI/lib/python3.5/site-packages/pyspark/rdd.py\u001b[0m in \u001b[0;36mcollect\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    814\u001b[0m         \"\"\"\n\u001b[1;32m    815\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mSCCallSiteSync\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcss\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 816\u001b[0;31m             \u001b[0msock_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jvm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPythonRDD\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollectAndServe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jrdd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrdd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    817\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_load_from_socket\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msock_info\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jrdd_deserializer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    818\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/software/env/AI/lib/python3.5/site-packages/py4j/java_gateway.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m   1253\u001b[0m             \u001b[0mproto\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEND_COMMAND_PART\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1254\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1255\u001b[0;31m         \u001b[0manswer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgateway_client\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend_command\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcommand\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1256\u001b[0m         return_value = get_return_value(\n\u001b[1;32m   1257\u001b[0m             answer, self.gateway_client, self.target_id, self.name)\n",
      "\u001b[0;32m~/software/env/AI/lib/python3.5/site-packages/py4j/java_gateway.py\u001b[0m in \u001b[0;36msend_command\u001b[0;34m(self, command, retry, binary)\u001b[0m\n\u001b[1;32m    983\u001b[0m         \u001b[0mconnection\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_connection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    984\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 985\u001b[0;31m             \u001b[0mresponse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconnection\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend_command\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcommand\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    986\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mbinary\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    987\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mresponse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_create_connection_guard\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconnection\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/software/env/AI/lib/python3.5/site-packages/py4j/java_gateway.py\u001b[0m in \u001b[0;36msend_command\u001b[0;34m(self, command)\u001b[0m\n\u001b[1;32m   1150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1151\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1152\u001b[0;31m             \u001b[0manswer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmart_decode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstream\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreadline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1153\u001b[0m             \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Answer received: {0}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0manswer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1154\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0manswer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstartswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mproto\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRETURN_MESSAGE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/lib/python3.5/socket.py\u001b[0m in \u001b[0;36mreadinto\u001b[0;34m(self, b)\u001b[0m\n\u001b[1;32m    573\u001b[0m         \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    574\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 575\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sock\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecv_into\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    576\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    577\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_timeout_occurred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/software/env/AI/lib/python3.5/site-packages/pyspark/context.py\u001b[0m in \u001b[0;36msignal_handler\u001b[0;34m(signal, frame)\u001b[0m\n\u001b[1;32m    268\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0msignal_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msignal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    269\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcancelAllJobs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 270\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    271\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    272\u001b[0m         \u001b[0;31m# see http://stackoverflow.com/questions/23206787/\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Wall-clock timing of the whole search; model is the object returned by train().\n",
    "start = time.time()\n",
    "model = train()\n",
    "end = time.time()\n",
    "\n",
    "# Elapsed seconds. Figures left by the author, presumably from earlier runs:\n",
    "# 1704.778949022293  1753.3879935741425\n",
    "print(end - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the selected parameters, their score and the corresponding estimator, one per line.\n",
    "print(model.best_params_, model.best_score_, model.best_estimator_, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In-sample predictions: X_train is the same data the search was fitted on.\n",
    "plabels = model.predict(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plabels are predictions for the rows the model was fitted on, so these are\n",
    "# in-sample (training) scores, not out-of-fold ones; the label says so.\n",
    "print('train f1', metrics.f1_score(y_train,plabels, average='macro'))\n",
    "print('train accuracy', accuracy_score(y_train,plabels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "生产"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the ID column into its own frame so the prediction column can be\n",
    "# added to it later without raising SettingWithCopyWarning.\n",
    "sub = test_all_data[['ID']].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the held-back rows, using every column except the label column.\n",
    "production_pred = model.predict(test_all_data.drop(columns=[\"type\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the predicted class ids. assign() returns a new frame, so nothing is\n",
    "# written into a slice of test_all_data (no SettingWithCopyWarning).\n",
    "sub = sub.assign(pred=production_pred)\n",
    "\n",
    "# Relative frequency of each predicted class.\n",
    "print(sub['pred'].value_counts(normalize=True))\n",
    "\n",
    "# Turn the class ids back into the original label strings and write the\n",
    "# result without index and without header row.\n",
    "sub = sub.assign(pred=sub['pred'].map({0:'围网',1:'刺网',2:'拖网'}))\n",
    "sub.to_csv('data/result_search.csv', index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
